In [44]:
USE_CUDA = False  # Enable to use cudf instead of pandas
VERBOSE = False  # Additional logging
EXTENTED_VIEW = False  # View additional data like model importances (SHAP is preferred)
In [45]:
import importlib
import time
from typing import Dict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import shap
import xgboost
from IPython.display import display, Markdown
from scipy.stats import (
    chi2_contingency,
    gaussian_kde,
    kruskal,
    kurtosis,
    normaltest,
    pointbiserialr,
    skew,
    spearmanr,
)
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler

import shared.definitions as model_defs
import shared.graph as graph
import shared.pipeline as pipeline_runner
import shared.stats_utils as stats_utils
import shared.utils as utils
import src.data_loader as data_loader
from shared import ml_config_core
from shared.pipeline import ModelTrainingResult
from src import model_config

if USE_CUDA:
    %load_ext cudf.pandas
In [46]:
# features.dtypes
In [47]:
utils.pandas_config(pd)
utils.plt_config(plt)

sns.set_theme(style="darkgrid", palette="pastel")
plt.style.use("fivethirtyeight")

Module 3, Sprint 2, Stroke Prediction Dataset¶

Introduction and Goals¶

Assuming that accurately predicting which individuals are at higher risk of stroke allows negative outcomes to be eliminated or mitigated through lifestyle changes, additional testing, monitoring, etc., we still need to balance the costs of such additional services against their potential benefits.

So we want to identify the highest possible number of potential strokes (i.e. maximize recall) while still maintaining high precision. This suggests optimizing our models based on the F1 score for the stroke = 1 class. However, we also need to take into account the costs associated with either outcome:

  • false positive: medium cost, additional tests and other services would be provided "unnecessarily" to individuals who have a low risk.
  • false negative: very high cost, may require immediate hospitalization and might result in death.

So we can tolerate a much higher proportion of false positives than false negatives. The exact ratio would depend on a more in-depth cost analysis (which could be performed by healthcare and insurance providers).

Therefore, in our analysis we'll focus on maximizing the recall/class accuracy for stroke = 1 (the proportion of false positives should of course still be considered and minimized as a secondary target).
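If the cost asymmetry should be baked into a single tuning metric, the F-beta score generalizes F1 by weighting recall beta times as much as precision. A minimal sketch with illustrative labels (not taken from the dataset):

```python
from sklearn.metrics import f1_score, fbeta_score, recall_score

# Illustrative predictions only (1 = stroke, 0 = no stroke):
# one missed stroke (false negative) and two false alarms (false positives).
y_true = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
y_pred = [1, 1, 1, 0, 1, 1, 0, 0, 0, 0]

f1 = f1_score(y_true, y_pred, pos_label=1)
# beta > 1 weights recall more heavily than precision, matching the
# "false negatives cost more" assumption.
f2 = fbeta_score(y_true, y_pred, beta=2, pos_label=1)
rec = recall_score(y_true, y_pred, pos_label=1)

print(f"recall={rec:.2f}  f1={f1:.2f}  f2={f2:.2f}")
# recall=0.75  f1=0.67  f2=0.71
```

With the same predictions, F2 sits closer to recall than F1 does, so a recall-heavy model scores better under it.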

Core Assumptions:¶

  • The cost of a false negative is higher than the cost of a false positive
  • Risk factors that significantly increase the likelihood of a stroke also affect other health issues with high mortality rates
    • i.e. people at very high risk of stroke are likely to have a lower life expectancy and may die before they have a chance to have one. This likely means our model will assign lower significance/importance to factors such as heart disease, obesity, or diabetes, while overestimating the effect of factors less correlated with other diseases.
  • The dataset is likely not representative and some overfitting is unavoidable

EDA & Model¶

Our ideal baseline would be the "simple" algorithms used by doctors and healthcare providers based on risk factors such as:

  • age
  • blood pressure
  • etc.

One important aspect to consider is that maximizing the overall performance of the model is not enough: classifying a "high-risk" individual as "low-risk" carries a much higher cost than the opposite.

Therefore we'll use two metrics when tuning our model:

  • macro F1 score
  • accuracy for the minority stroke class

1.1 Analysis of individuals features and their distributions¶

In [48]:
importlib.reload(data_loader)

source_df, eda_df, model_df, features, labels = data_loader.load_data()
In [49]:
eda_df_ext = eda_df.copy()
eda_df_ext["age_binned"] = pd.cut(
    eda_df_ext["age"], [0, 12, 18, 30, 40, 50, 60, 70, 80, 100], right=True
)

# Binning:
# < 80 - low
# 80 - 100 - normal
# 100-125 elevated
# > 125 high


eda_df_ext["avg_glucose_level_binned"] = pd.cut(
    eda_df_ext["avg_glucose_level"], [0, 80, 100, 125, 999], right=True
)

# Binning:
# < 18.5 - under
# 18.5 - 25 - normal
# 25 - 30 elevated
# > 30 obese

eda_df_ext["bmi_binned"] = pd.cut(eda_df_ext["bmi"], [0, 18.5, 25, 30, 999], right=True)
eda_df_ext["bmi_binned_cats"] = eda_df_ext["bmi_binned"]
eda_df_ext["bmi_binned_cats"] = eda_df_ext["bmi_binned_cats"].cat.rename_categories(
    ["Underweight", "Normal", "Overweight", "Obese"]
)
eda_df = eda_df[eda_df["gender"] != "Other"]
In [50]:
source_df_no_cat = eda_df.copy()
In [51]:
if VERBOSE:
    counts = eda_df["stroke"].value_counts()
    display(
        pd.DataFrame(
            {
                "Counts": counts,
                "Percentage": (counts / counts.sum()).mul(100).round(2).astype(str)
                + "%",
            }
        )
    )
In [52]:
if VERBOSE:
    nan_counts = source_df.isna().sum()
    nan_counts = nan_counts[nan_counts > 0]
    if len(nan_counts) > 0:
        display("NaN Values:")
        display(nan_counts)
    # There is no missing data

The charts below show the distributions of all the features included in the dataset:

  1. Numerical features are displayed using KDE and boxen plots, with additional testing for normality.
  2. Value counts are shown for non-numerical features.
In [53]:
scaler = StandardScaler()
source_df_scaled = source_df_no_cat.select_dtypes(exclude=["object"])


def clean_tick_label(v):
    label = str(v).split(" ")[0]
    label = label.replace("_", " ")
    label = label.title()

    label = "Yes" if label == "1" else label
    label = "No" if label == "0" else label

    return label


for i, variable in enumerate(source_df_no_cat.columns):
    col_vals = source_df_no_cat[variable]
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))
    axes[0].set_aspect(aspect="auto", adjustable="box")
    axes[1].set_aspect(aspect="auto", adjustable="box")

    try:
        if stats_utils.is_non_numerical_or_discrete(col_vals):
            unique_values = col_vals.unique()

            counts = source_df_no_cat[variable].value_counts()
            count_plot_df = source_df_no_cat[
                source_df_no_cat[variable].isin(counts[counts >= 50].index)
            ].copy()  # copy to avoid SettingWithCopyWarning on the assignment below
            count_plot_df[variable] = count_plot_df[variable].astype("str")
            sns.countplot(x=variable, data=count_plot_df, ax=axes[0], width=0.55)

            axes[0].set_title("Frequency Distribution", fontsize=12)
            tick_labels = [
                clean_tick_label(v) for v in list(count_plot_df[variable].unique())
            ]

            # Fix the tick positions first so set_xticklabels has a FixedLocator
            axes[0].set_xticks(range(len(tick_labels)))
            axes[0].set_xticklabels(tick_labels)
            axes[0].set_xlabel("")
            y_limit = ((len(source_df_no_cat) // 500) + 1) * 500
            axes[0].set_ylim([0, y_limit])

            proportions = col_vals.value_counts(normalize=True)
            explode = [0.02] * len(proportions)
            wedges, _ = axes[1].pie(
                proportions,
                labels=None,
                autopct=None,
                startangle=140,
                wedgeprops=dict(width=0.3),
                explode=explode,
            )

            for wedge, label in zip(wedges, proportions.index):
                angle = (wedge.theta2 + wedge.theta1) / 2
                x = np.cos(np.deg2rad(angle))
                y = np.sin(np.deg2rad(angle))
                horizontalalignment = {-1: "right", 1: "left"}[int(np.sign(x))]
                connectionstyle = "angle,angleA=0,angleB={}".format(angle)
                pct = round(np.round(wedge.theta2 - wedge.theta1) / 360 * 100, 1)
                label_tr = clean_tick_label(label)
                # label_tr = str(label).replace("_", " ").title()
                axes[1].annotate(
                    f"{label_tr}: {pct}%",
                    xy=(x / 2, y / 2),
                    xytext=(1.15 * x, 1.15 * y),
                    arrowprops=dict(arrowstyle="-", connectionstyle=connectionstyle),
                    horizontalalignment=horizontalalignment,
                )

        else:
            sns.kdeplot(x=variable, data=source_df_no_cat, ax=axes[0], fill=True)
            axes[0].set_title("Original Data Distribution (KDE)", fontsize=12)
            axes[0].set_ylabel("")
            axes[0].set_yticks([])

            # Statistical Annotations on KDE Plot
            mean_val = col_vals.mean()
            std_dev = col_vals.std()
            stat, p_val = normaltest(col_vals)
            normality = "Normal" if p_val > 0.05 else "Not normal"
            axes[0].text(
                0.95,
                0.8,
                f"{normality}\nP-val: {round(p_val, 3)}\nMean: {mean_val:.2f}\nStd Dev: {std_dev:.2f}",
                ha="right",
                va="center",
                transform=axes[0].transAxes,
                fontsize=9,
                bbox=dict(
                    boxstyle="round,pad=0.3", edgecolor="blue", facecolor="white"
                ),
            )

            sns.boxenplot(
                x=variable,
                data=source_df_no_cat,
                ax=axes[1],
                width=0.4,
            )
            axes[1].set_yticks([])

            Q3 = source_df_no_cat[variable].quantile(0.75)
            Q1 = source_df_no_cat[variable].quantile(0.25)
            IQR = Q3 - Q1
            upper_whisker = Q3 + 1.5 * IQR
            std_dev = source_df_no_cat[variable].std()

            upper_bound = upper_whisker + 0.5 * std_dev
            axes[1].set_xlim(0, min(upper_bound, source_df_no_cat[variable].max()))

            skewness = skew(col_vals)
            excess_kurtosis = kurtosis(col_vals)
            axes[1].text(
                0.95,
                0.85,
                f"Skew: {skewness:.2f}\nKurt: {excess_kurtosis:.2f}",
                ha="right",
                va="center",
                transform=axes[1].transAxes,
                fontsize=9,
                bbox=dict(
                    boxstyle="round,pad=0.3", edgecolor="blue", facecolor="white"
                ),
            )

        for ax in axes:
            ax.set_title(ax.get_title(), fontsize=12)
            ax.set_xlabel(ax.get_xlabel(), fontsize=10)
            ax.set_ylabel(ax.get_ylabel(), fontsize=10)

        title = variable.replace("_", " ").title()
        if title == "Bmi":
            title = "BMI"
        plt.suptitle(title, fontsize=16, y=1.02)

        plt.tight_layout()
        plt.show()

    except Exception as ex:
        plt.close(fig)
        raise ex

1.2 Relationships Between Features¶

In [54]:
def correlation_test(x, y):
    # Convert 'category' dtype to 'object' for unified handling
    if x.dtype.name == "category":
        x = x.astype("object")
    if y.dtype.name == "category":
        y = y.astype("object")

    def chi_squared_test(x, y):
        contingency_table = pd.crosstab(x, y)
        chi2, p, _, _ = chi2_contingency(contingency_table)
        n = np.sum(contingency_table.values)
        k, r = contingency_table.shape
        cramers_v = np.sqrt(chi2 / (n * min(k - 1, r - 1)))
        return cramers_v, p

    if x.dtype == "object" or y.dtype == "object":
        return chi_squared_test(x, y)
    elif x.dtype == "bool" and y.dtype in ["int64", "float64"]:
        return pointbiserialr(x, y)
    elif y.dtype == "bool" and x.dtype in ["int64", "float64"]:
        return pointbiserialr(y, x)
    elif x.dtype in ["int64", "float64"] and y.dtype in ["int64", "float64"]:
        return spearmanr(x, y)
    else:
        raise ValueError(
            f"Unsupported data types for correlation test: {x.dtype} and {y.dtype}"
        )


corr = pd.DataFrame(index=eda_df.columns, columns=eda_df.columns)
p_values = pd.DataFrame(index=eda_df.columns, columns=eda_df.columns)

for col1 in eda_df.columns:
    for col2 in eda_df.columns:
        if col1 != col2:
            corr_value, p_value = correlation_test(eda_df[col1], eda_df[col2])
            corr.loc[col1, col2] = corr_value
            p_values.loc[col1, col2] = p_value
        else:
            # Handle the diagonal (correlation with self)
            corr.loc[col1, col2] = 1.0
            p_values.loc[col1, col2] = 0.0

mask = np.triu(np.ones_like(corr, dtype=bool))

f, ax = plt.subplots(figsize=(13, 11))

cmap = sns.diverging_palette(230, 20, as_cmap=True)
corr = corr.astype(float).round(2)

significant_mask = np.abs(corr) >= 0.1
combined_mask = mask | ~significant_mask


def format_annotation(corr_value, p_value):
    # Guard against NaNs to avoid "invalid value" runtime warnings
    if np.isnan(corr_value) or np.isnan(p_value):
        return ""
    if p_value < 0.05:
        return f"{corr_value:.2f}*"
    elif abs(corr_value) > 0.35:
        return f"{corr_value:.2f}\np={p_value:.2f}"
    return ""


vectorized_formatter = np.vectorize(format_annotation)

sns.heatmap(
    corr,
    mask=combined_mask,
    cmap=cmap,
    vmax=1,
    vmin=-1,
    center=0,
    square=True,
    linewidths=0.5,
    cbar_kws={"shrink": 0.5},
    annot=vectorized_formatter(corr.to_numpy(), p_values.to_numpy()),
    fmt="",
)

plt.title("Correlation between variable pairs")
plt.annotate(
    "* p < 0.05\nonly pairs where |correlation| >= 0.1 are shown",
    xy=(0.5, -0.175),
    xycoords="axes fraction",
    xytext=(0, -40),
    textcoords="offset points",
    ha="center",
    va="top",
)
plt.show()
In [55]:
if VERBOSE:
    display(corr)
Because the datatypes of the features vary, we had to use different methods to measure the strength and significance of each pair:

- Chi-Squared Test: assesses independence between two categorical variables. Used for categorical and bool-bool pairs.

- Point-Biserial Correlation: measures the correlation between a binary and a continuous variable. Used for bool-numerical pairs.

- Spearman's Rank Correlation: assesses the monotonic relationship between two continuous variables. Used for numerical-numerical pairs (appropriate for non-normally distributed data).

Since the Chi-Squared test outputs an unbounded statistic which can't be directly compared to the point-biserial or Spearman values, we converted it to Cramér's V, which is normalized between 0 and 1. This makes the values in the matrix more uniform; however, we must note that Cramér's V and Spearman's correlation coefficients are fundamentally different statistics and generally can't be directly compared.
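For reference, Cramér's V is derived from the chi-squared statistic exactly as in the `correlation_test` helper above; a small sketch on toy categorical data (illustrative values only):

```python
import numpy as np
import pandas as pd
from scipy.stats import chi2_contingency

# Toy categorical data (illustrative values only)
x = pd.Series(["a", "a", "a", "b", "b", "b", "c", "c", "c"])
y = pd.Series(["u", "u", "u", "v", "v", "v", "u", "v", "u"])

table = pd.crosstab(x, y)
chi2, p, _, _ = chi2_contingency(table)
n = table.values.sum()
k, r = table.shape
# Cramér's V rescales the unbounded chi-squared statistic into [0, 1]
cramers_v = np.sqrt(chi2 / (n * min(k - 1, r - 1)))
print(round(cramers_v, 3), round(p, 3))
```

A value near 0 means the two variables are close to independent; a value near 1 means one almost determines the other.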
In [56]:
def draw_distribution_pie_charts(split_var="gender", include_cols=None):
    if include_cols is None:
        include_cols = ["ever_married", "work_type", "Residence_type", "smoking_status", "bmi_binned_cats"]

    ii_empl_df = eda_df_ext[[split_var, *include_cols]]

    fig, axes = plt.subplots(len(include_cols), 2, figsize=(12, len(include_cols)*5))

    for i, column in enumerate(include_cols):
        for j, target in enumerate(ii_empl_df[split_var].unique()):
            data = ii_empl_df[ii_empl_df[split_var] == target][column].value_counts()
            pie_labels = [
                f"{index}" for index, pct in zip(data.index, data * 100 / data.sum())
            ]
            axes[i, j].set_title(f"{column} for {target}", fontdict={"fontsize": 12})
            
            # proportions = col_vals.value_counts(normalize=True)
            explode = [0.02] * len(data)
            wedges, _ = axes[i, j].pie(
                data,
                labels=None,
                autopct=None,
                startangle=140,
                wedgeprops=dict(width=0.3),
                explode=explode,
            )
            # axes[i, j].pie(data, labels=pie_labels, autopct="%1.1f%%", startangle=90)

            # 
            for wedge, label in zip(wedges, data.index):
                angle = (wedge.theta2 + wedge.theta1) / 2
                x = np.cos(np.deg2rad(angle))
                y = np.sin(np.deg2rad(angle))
                horizontalalignment = {-1: "right", 1: "left"}[int(np.sign(x))]
                connectionstyle = "angle,angleA=0,angleB={}".format(angle)
                pct = round(np.round(wedge.theta2 - wedge.theta1) / 360 * 100, 1)
                label_tr = clean_tick_label(label)
                # label_tr = str(label).replace("_", " ").title()
                axes[i, j].annotate(
                    f"{label_tr}: {pct}%",
                    xy=(x / 2, y / 2),
                    xytext=(1.15 * x, 1.15 * y),
                    arrowprops=dict(arrowstyle="-", connectionstyle=connectionstyle),
                    horizontalalignment=horizontalalignment,
                )

            
    plt.suptitle(split_var.replace("_", " ").title(), fontsize=16, y=1.02)

    plt.tight_layout()
    plt.show()


# TODO: MAKE SURE THE COLORS ALWAYS MATCH BY LABEL BETWEEN PLOTS!
In [57]:
draw_distribution_pie_charts(
    split_var="stroke",
    include_cols=["smoking_status", "heart_disease", "bmi_binned_cats", "ever_married"]
)
In [58]:
draw_distribution_pie_charts(
    split_var="heart_disease",
    include_cols=["stroke", "smoking_status"]
)
In [59]:
sns.displot(
    data=eda_df_ext,
    x="age",
    hue="bmi_binned_cats",
    kind="kde",
    height=6,
    multiple="fill",
    clip=(10, 80),
)
plt.title("Weight and Age", x=0.5, y=1.025, fontdict={"size": 16})
Out[59]:
Text(0.5, 1.025, 'Weight and Age')
In [60]:
def boxen_plot_by_cat(c, y_target):
    _df = eda_df_ext.copy()
    _df = _df[_df[y_target].notna()]  # use the parameter, not the global target_y
    grouped = _df.groupby(c)[y_target]

    stat, p_value = kruskal(*[group for name, group in grouped])
    test_explain = f"Kruskal-Wallis Test for {c} vs {y_target}: p-value = {p_value:.3f}"

    if p_value < 0.05:
        group_counts = _df.groupby(c).size()
        log_base = 2
        max_width = 0.8

        log_widths = np.log(group_counts + 1) / np.log(log_base)
        normalized_widths = log_widths / log_widths.max()
        scaled_widths = normalized_widths * max_width
        min_width = 0.05
        final_widths = scaled_widths.clip(min_width)

        plt.figure(figsize=(9, 6))
        for group in normalized_widths.index:
            sns.boxenplot(
                data=_df[_df[c] == group],
                x=c,
                y=y_target,
                color="b",
                width=final_widths[group],
            )
        plt.title(f"{' '.join(c.split('_')).title()}\n", fontdict={"fontsize": 18})
        plt.xticks(
            ticks=range(len(group_counts)),
            labels=[f"{group}\nn={count}" for group, count in group_counts.items()],
        )
        plt.annotate(
            test_explain,
            fontsize=12,
            xy=(0, -0.05),
            xycoords="axes fraction",
            xytext=(0, -40),
            textcoords="offset points",
            ha="left",
            va="top",
        )

        plt.xlabel("")
        plt.show()
    else:
        if VERBOSE:
            print(
                f"{c} vs {y_target} No significant difference found (p-value = {p_value:.3f})"
            )


for target_y in ["age", "bmi", "avg_glucose_level"]:
    for c in [
        "stroke",
        "gender",
        "age_binned",
        "bmi_binned",
        "avg_glucose_level_binned",
        "work_type",
        "Residence_type",
        "ever_married",
        "heart_disease",
        "hypertension",
    ]:
        if target_y in c and "binned" in c:
            continue
        boxen_plot_by_cat(c, target_y)

Risk Factor Analysis¶

In this part we'll look into the relationship between specific risk factors which we would assume to be significantly related to the likelihood of having a stroke (based both on correlation and subject knowledge):

  • age
  • hypertension
  • heart_disease
  • avg_glucose_level
  • bmi
  • smoking_status

The KDE plots show the likelihood of having a stroke at a specific age if the patient has any of the listed risk factors (the Y axis is relative to the full sample of individuals with the risk factor, not just people who have the condition and had a stroke).

In [61]:
risk_df = eda_df_ext[
    [
        "age",
        "hypertension",
        "heart_disease",
        "avg_glucose_level",
        "bmi",
        "stroke",
        "smoking_status",
    ]
].copy()

risk_df["high_glucose"] = risk_df["avg_glucose_level"] > 125
risk_df["overweight"] = risk_df["bmi"] > 25
risk_df["obese"] = risk_df["bmi"] > 30
risk_df["smoking_status"] = risk_df["smoking_status"] == "smokes"
In [62]:
risk_factors = [
    "hypertension",
    "heart_disease",
    "high_glucose",
    "overweight",
    "obese",
    "smoking_status",
]

risk_df["risk_factor_count_raw"] = risk_df[risk_factors].sum(axis=1)
risk_df["risk_factor_count"] = risk_df["risk_factor_count_raw"].apply(
    lambda x: "3+" if x >= 3 else str(int(x))
)
In [63]:
# displot creates its own figure (and its own legend), so a preceding
# plt.figure() would just leave an empty figure behind
sns.displot(
    data=risk_df,
    x="age",
    hue="risk_factor_count",
    kind="kde",
    height=6,
    multiple="fill",
    clip=(0, None),
)

plt.xlim([12, 82])
plt.xlabel("Age")
plt.ylabel("Density")
plt.title("Number of Risk Factors Based on Age", x=0.5, y=1.025, fontdict={"size": 16})
plt.show()

We can see that the number of risk factors on average increases until the age of ~60, after which it starts to slightly decrease. This is likely a case of survivorship bias, as most of these factors tend to have a negative effect on life expectancy.

In [64]:
plt.figure(figsize=(12, 6))

risk_factor_counts = risk_df["risk_factor_count"].unique()
age_range = np.linspace(risk_df["age"].min(), risk_df["age"].max(), 300)

for count in risk_factor_counts:
    total_in_group = len(risk_df[risk_df["risk_factor_count"] == count])

    subset = risk_df[
        (risk_df["risk_factor_count"] == count) & (risk_df["stroke"] == 1)
    ]["age"]

    if len(subset) > 0:
        kde = gaussian_kde(subset, bw_method="silverman")
        density = kde(age_range) * len(subset) / total_in_group

        sns.lineplot(x=age_range, y=density, label=f"Risk Factors: {count}")

plt.xlim([12, 82])
plt.xlabel("Age")
plt.ylabel("Probability")
plt.title("Probability of Having A Stroke by Age", x=0.5, y=1.05, fontdict={"size": 18})

desc = "This is a KDE plot which shows the estimated probability of having a stroke based on\n the number of risk factors at a given age (the density is relative to the size of the entire 'Risk Factors: N' sample)"
plt.annotate(
    desc,
    fontsize=12,
    xy=(0, 0),
    xycoords="axes fraction",
    xytext=(0, -40),
    textcoords="offset points",
    ha="left",
    va="top",
)

plt.legend()
plt.show()

This chart shows individual KDE density curves for each subgroup based on age (it can be interpreted similarly to a histogram).

Interestingly, the difference is most prominent below ~65; afterwards the effect of having just 1 or 2 risk factors is much lower.

We can see that people who are not overweight, do not smoke, and do not have elevated glucose levels or heart issues have a much lower probability of having a stroke as long as they are younger than 60.

In [65]:
risk_df_bool = risk_df.apply(
    lambda col: col.astype(bool) if col.isin([0, 1]).all() else col
)
risk_df_bool["NO_RISK_FACTORS"] = ~risk_df_bool[risk_factors].any(axis=1)
In [66]:
plt.figure(figsize=(12, 6))

palette = sns.color_palette("husl", n_colors=len(risk_factors))
for ii, risk_factor in enumerate(risk_factors + ["NO_RISK_FACTORS"]):
    total_in_group = len(risk_df_bool[risk_df_bool[risk_factor] == True])
    subset = risk_df_bool[
        (risk_df_bool[risk_factor] == True) & (risk_df_bool["stroke"] == 1)
    ]["age"]

    if len(subset) > 0:
        kde = gaussian_kde(subset, bw_method="silverman")
        density = kde(age_range) * len(subset) / total_in_group

        risk_factor_formatted = risk_factor.replace("_", " ").title()

        if len(palette) > ii:
            sns.lineplot(
                x=age_range,
                y=density,
                label=f"{risk_factor_formatted}",
                color=palette[ii],
                linewidth=2.5,
            )
        else:
            sns.lineplot(
                x=age_range,
                y=density,
                label=f"{risk_factor_formatted}",
                color="grey",
                linewidth=5,
                alpha=0.5,
            )

plt.xlim([12, 82])
plt.xlabel("Age")
plt.ylabel("Density")
plt.title(
    "Probability of Having A Stroke by Individual Risk Factor",
    x=0.5,
    y=1.05,
    fontdict={"size": 18},
)

desc = "This is a KDE plot which shows the estimated probability of having a stroke for each individual risk factor\n at a given age (the density is relative to the size of each risk factor's subgroup)"
plt.annotate(
    desc,
    fontsize=14,
    xy=(0, 0),
    xycoords="axes fraction",
    xytext=(0, -40),
    textcoords="offset points",
    ha="left",
    va="top",
)

plt.legend()
plt.show()

Generally, most of the risk factors besides having a heart disease seem to have a similar effect below the age of 60; afterwards, having high glucose levels (diabetes etc.) or hypertension has a much larger effect.

In [67]:
risk_df_t = risk_df[[*risk_factors, "stroke", "age"]].copy()
melted_df = pd.melt(risk_df_t, id_vars=["age", "stroke"], var_name="risk_factor")

1.5 PCA¶

We have attempted to use PCA to reduce the dimensionality of the dataset.

This can be necessary for datasets with very high numbers of features. Since this specific dataset is very simple and includes very few columns, this was done only for informative/educational purposes.

Additionally, we have included binary/categorical variables, which is also generally not advisable in real-world cases.

While PCA can be used as a preprocessing step (and we have experimented with using it for simple logistic or SVM models), this is generally not necessary for simple datasets like this one.
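A minimal sketch of PCA as a preprocessing step inside a scikit-learn pipeline (synthetic data, illustrative only; the actual run uses the project's own pipeline):

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 10))          # synthetic features
y = (X[:, 0] + X[:, 1] > 0).astype(int)  # synthetic binary target

# Standardize first so PCA is not dominated by high-variance columns,
# then keep just enough components to explain 80% of the variance.
pipe = Pipeline([
    ("scale", StandardScaler()),
    ("pca", PCA(n_components=0.8)),
    ("clf", LogisticRegression()),
])
pipe.fit(X, y)
print(pipe.named_steps["pca"].n_components_)
```

Passing a float in (0, 1) to `n_components` tells PCA to pick the smallest component count reaching that variance threshold.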

In [67]:
 
In [68]:
importlib.reload(stats_utils)
importlib.reload(graph)

# TODO: either only look at numerical values or do MCA

_explained_pc = stats_utils.get_pca_explained(features)
graph.render_pca_component_plot(
    _explained_pc, title="PCA (only features) Cumulative Variance"
)
display(f"Total Feature Count: {len(features.columns)}")
No description has been provided for this image
'Total Feature Count: 10'

PCA was done using a scikit-learn pipeline which handles standardization for the numerical variables.

We can see that the dataset (not including the target variable) could effectively be reduced to 8 components (which preserve about 80% of the variance). Since this isn't much lower than the total number of variables, it's not particularly useful for ML or even visualization purposes.
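The component count is read off the cumulative explained variance; a small sketch with synthetic data (the 80% threshold mirrors the one used above):

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(1)
X = rng.normal(size=(300, 10))
X[:, 5] = X[:, 0] + 0.1 * rng.normal(size=300)  # make one column mostly redundant

pca = PCA().fit(StandardScaler().fit_transform(X))
cumvar = np.cumsum(pca.explained_variance_ratio_)
# smallest number of components preserving at least 80% of the variance
n_components = int(np.searchsorted(cumvar, 0.80)) + 1
print(n_components, round(cumvar[n_components - 1], 3))
```

With one redundant column, slightly fewer than the full 10 components are needed, which is the same "elbow" reading done on the cumulative-variance plot above.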

2. ML Models¶

We have used several different models. Our process included these steps:

  1. Define separate configurations for each model based on the target variables/metrics used for tuning (see src/model_config.py and shared/ml_config_core.py). We have tested these models:
  • XGBoost
  • CatBoost
  • LGBM
  • SVM
  • Random Forest
  • Custom ensemble model (logistic regression + SVM + KNN with a soft voting classifier)

  2. Train and validate using Stratified KFolds (5 folds).
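As a sanity check on the splitting strategy, StratifiedKFold preserves the class ratio in every fold, which matters with a heavily imbalanced target; a small sketch with synthetic labels (roughly the stroke-class imbalance, values illustrative):

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Imbalanced toy labels: 5% positives, similar in spirit to the stroke class
y = np.array([1] * 5 + [0] * 95)
X = np.zeros((100, 1))

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
for train_idx, test_idx in cv.split(X, y):
    # every test fold keeps the overall 5% positive rate
    print(y[test_idx].mean())  # 0.05 for each of the 5 folds
```

A plain KFold could easily produce folds with zero positives here, making recall on the minority class undefined.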

  3. Perform hyperparameter tuning for each model. Because the dataset is heavily imbalanced we used several different target metrics:
  • macro F1
  • recall (only the target class)
  • F1 (only the target class)
  • F-beta (only the target class)

Built-in class-weight parameters were used for all models besides the ensemble one, which uses SMOTE, ADASYN, standard oversampling, etc. The results for each individual model are stored separately in the .tuning_results folder.
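To illustrate the built-in class-weighting approach on synthetic data (the real configurations live in src/model_config.py): scikit-learn's `class_weight="balanced"` reweights each class inversely to its frequency, and the printed neg/pos ratio is what an XGBoost-style `scale_pos_weight` would typically be set to:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
# Synthetic imbalanced data: ~5% positives, like the stroke label
n = 1000
y = (rng.random(n) < 0.05).astype(int)
X = rng.normal(size=(n, 3)) + y[:, None]  # shift positives so they are learnable

# "balanced" weights each class by n_samples / (n_classes * class_count),
# so the minority class contributes as much to the loss as the majority
clf = LogisticRegression(class_weight="balanced").fit(X, y)

# Equivalent ratio for an XGBoost-style scale_pos_weight parameter
neg, pos = np.bincount(y)
print(round(neg / pos, 1))
```

Unlike SMOTE/ADASYN, this changes only the loss weighting and adds no synthetic samples, so it composes cleanly with cross-validation.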

In [69]:
importlib.reload(model_defs)
importlib.reload(pipeline_runner)
importlib.reload(model_config)
importlib.reload(ml_config_core)
model_configs = model_config.get_config()
include_models = [k for k in model_configs.keys() if model_configs[k].tune]

res = {}
cv = StratifiedKFold(n_splits=5, random_state=42, shuffle=True)

for model_name in include_models:
    cfg = model_configs[model_name].model

    start_time = time.time()

    if VERBOSE:
        print(
            f"Tuning {model_name} n_iters={cfg.search_n_iter} with:\n {cfg.param_grid}"
        )

    tunning_result = pipeline_runner.run_tunning_for_config(
        model_key=model_name,
        config=cfg,
        cv=cv,
        features=features,
        labels=labels,
    )
    tunning_result.to_yaml()
    res[model_name] = tunning_result

    end_time = time.time()
    elapsed_time = round(end_time - start_time, 1)

    if VERBOSE:
        print(f"Total time:{elapsed_time}\n")

tunning_result_res_df = model_defs.TuningResult.convert_to_dataframe(res)
Using balancing config: UnderSamplingConfig
Using <class 'sklearn.model_selection._search.RandomizedSearchCV'> with n_iter=250
Using <class 'sklearn.model_selection._search.RandomizedSearchCV'> with n_iter=250
In [70]:
tunning_result_res_df
Out[70]:
best_score best_params search_type model_config_reference
model_key
XGBoostCatF1UndersampleAuto 0.191422 {'model__scale_pos_weight': 1, 'model__n_estimators': 250, 'model__min_child_weight': 1.5, 'model__max_depth': 6, 'model__learning_rate': 0.01, 'model__gamma': 0.3} Random XGBoostCatF1UndersampleAuto(model=<class 'xgboost.sklearn.XGBClassifier'>, supports_nan=True, param_grid={'model__learning_rate': [0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 1], 'model__max_depth': [4, 5, 6, 7, 10, 12, None], 'model__n_estimators': [50, 100, 150, 200, 250], 'model__min_child_weight': [0.1, 0.25, 0.5, 0.75, 1, 1.5, 2, 2.5, 3], 'model__gamma': [0, 0.05, 0.1, 0.3, 0.4], 'model__scale_pos_weight': [1, 5, 10, 20, 25, 30, 35, 40]}, builtin_params={'enable_categorical': True}, search_n_iter=250, balancing_config=UnderSamplingConfig(params={}), preprocessing=FunctionTransformer(func=<function preprocessing_for_xgboost.<locals>.convert_to_category at 0x7eff4b463eb0>), tunning_func_target=make_scorer(f1_score, pos_label=1), best_params=None, ensemble_classifier=None)
XGBoostTuneCatFBeta_25 0.433492 {'model__scale_pos_weight': 25, 'model__n_estimators': 250, 'model__min_child_weight': 1.5, 'model__max_depth': 4, 'model__learning_rate': 0.01, 'model__gamma': 0.1} Random XGBoostTuneCatFBeta_25(model=<class 'xgboost.sklearn.XGBClassifier'>, supports_nan=True, param_grid={'model__learning_rate': [0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 1], 'model__max_depth': [4, 5, 6, 7, 10, 12, None], 'model__n_estimators': [50, 100, 150, 200, 250], 'model__min_child_weight': [0.1, 0.25, 0.5, 0.75, 1, 1.5, 2, 2.5, 3], 'model__gamma': [0, 0.05, 0.1, 0.3, 0.4], 'model__scale_pos_weight': [1, 5, 10, 20, 25, 30, 35, 40]}, builtin_params={'enable_categorical': True}, search_n_iter=250, balancing_config=None, preprocessing=FunctionTransformer(func=<function preprocessing_for_xgboost.<locals>.convert_to_category at 0x7eff4b463eb0>), tunning_func_target=make_scorer(fbeta_score, beta=2.5, pos_label=1), best_params=None, ensemble_classifier=None)
In [71]:
importlib.reload(pipeline_runner)
importlib.reload(model_config)
importlib.reload(ml_config_core)
importlib.reload(stats_utils)
importlib.reload(model_defs)
Out[71]:
<module 'shared.definitions' from '/home/paulius/data/projects/health_m3_s2/shared/definitions.py'>
In [72]:
importlib.reload(pipeline_runner)
importlib.reload(model_config)
importlib.reload(ml_config_core)
importlib.reload(stats_utils)
importlib.reload(model_defs)

model_task_infos = model_config.get_config()
model_configs = {key: value.model for key, value in model_task_infos.items()}

model_configs_with_params = (
    model_defs.TuningResultsAPI.get_model_configs_with_hyperparams(
        model_configs, skip_missing=True
    )
)

cv = StratifiedKFold(n_splits=5, random_state=42, shuffle=True)

cv_results: Dict[str, ModelTrainingResult] = {}
for model_name, cfg in model_configs_with_params.items():
    start_time = time.time()

    result = pipeline_runner.run_pipeline_config(
        config=cfg,
        export_prod=True,
        cv=cv,
        features=features,
        labels=labels,
        export_test=True,
    )
    end_time = time.time()
    elapsed_time = round(end_time - start_time, 1)

    print(f"{model_name}: {elapsed_time} seconds")

    cv_results[model_name] = result

cv_results_df = pipeline_runner.build_cv_results_table(cv_results, VERBOSE=False)
cv_results_df = cv_results_df.sort_values(by="fbeta_2.5", ascending=False)
LGBMForestBaseConfigTuneFBeta_25: 1.8 seconds
Using balancing config: UnderSamplingConfig
Using balancing config: UnderSamplingConfig
Using balancing config: UnderSamplingConfig
XGBoostCatF1UndersampleAuto: 0.5 seconds
Using balancing config: SmoteConfig
Using balancing config: SmoteConfig
Using balancing config: SmoteConfig
Using balancing config: SmoteConfig
Ensemble_Log_KNN_SVM_SMOTE: 32.6 seconds
XGBoostTuneCatFBeta_25: 0.9 seconds
XGBoostTuneCatFBeta_325: 0.6 seconds
XGBoostTuneCatFBeta_40: 0.6 seconds
XGBoostTuneCatFBeta_50: 0.6 seconds
XGBoostTuneRecall: 0.5 seconds
CatBoostBaseConfigTuneFBeta_15: 1.2 seconds
CatBoostBaseConfigTuneFBeta_20: 0.5 seconds
CatBoostBaseConfigTuneFBeta_25: 1.6 seconds
CatBoostBaseConfigTuneFBeta_325: 0.6 seconds
CatBoostBaseConfigTuneFBeta_40: 0.5 seconds
CatBoostBaseConfigTuneRecall: 0.4 seconds
Results¶

The table below shows the results for each configuration using the optimal parameters:

In [73]:
cv_results_df
Out[73]:
accuracy precision_macro recall_macro f1_macro target_f1 target_recall target_precision fbeta_1.5 fbeta_2.5 fbeta_4.0 n_samples
XGBoostTuneCatFBeta_25 0.729 0.550 0.758 0.518 0.199 0.789 0.114 0.279 0.434 0.585 4908.0
XGBoostCatF1UndersampleAuto 0.712 0.548 0.754 0.508 0.191 0.799 0.109 0.270 0.426 0.582 4908.0
XGBoostTuneRecall 0.717 0.542 0.715 0.503 0.177 0.713 0.101 0.249 0.388 0.525 4908.0
Ensemble_Log_KNN_SVM_SMOTE 0.844 0.544 0.635 0.548 0.182 0.407 0.117 0.231 0.303 0.355 4908.0
XGBoostTuneCatFBeta_325 0.897 0.561 0.619 0.576 0.207 0.316 0.153 0.238 0.276 0.297 4908.0
XGBoostTuneCatFBeta_40 0.897 0.561 0.619 0.576 0.207 0.316 0.153 0.238 0.276 0.297 4908.0
XGBoostTuneCatFBeta_50 0.897 0.561 0.619 0.576 0.207 0.316 0.153 0.238 0.276 0.297 4908.0
CatBoostBaseConfigTuneFBeta_25 0.707 0.518 0.593 0.472 0.120 0.469 0.069 0.168 0.260 0.349 4908.0
CatBoostBaseConfigTuneFBeta_15 0.710 0.518 0.592 0.473 0.120 0.464 0.069 0.168 0.259 0.347 4908.0
CatBoostBaseConfigTuneFBeta_20 0.710 0.518 0.592 0.473 0.120 0.464 0.069 0.168 0.259 0.347 4908.0
LGBMForestBaseConfigTuneFBeta_25 0.330 0.511 0.557 0.281 0.093 0.804 0.049 0.141 0.258 0.423 4908.0
CatBoostBaseConfigTuneFBeta_325 0.360 0.510 0.554 0.299 0.092 0.766 0.049 0.140 0.254 0.412 4908.0
CatBoostBaseConfigTuneFBeta_40 0.360 0.510 0.554 0.299 0.092 0.766 0.049 0.140 0.254 0.412 4908.0
CatBoostBaseConfigTuneRecall 0.360 0.510 0.554 0.299 0.092 0.766 0.049 0.140 0.254 0.412 4908.0
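Since the table is sorted by `fbeta_2.5`, it is worth spelling the metric out: Fβ = (1 + β²)·P·R / (β²·P + R), where larger β weights recall more heavily. As a sanity check, plugging in the target precision/recall from the top row reproduces its reported score:

```python
def fbeta(precision, recall, beta):
    """F-beta: like F1 but weights recall beta times as heavily as precision."""
    b2 = beta ** 2
    return (1 + b2) * precision * recall / (b2 * precision + recall)

# Top row: XGBoostTuneCatFBeta_25 with target_precision=0.114 and
# target_recall=0.789 -> matches the reported fbeta_2.5 of 0.434
print(round(fbeta(0.114, 0.789, 2.5), 3))  # 0.434
```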
In [74]:
# Ensemble Model Visualization Example
if VERBOSE:
    for model, res in cv_results.items():
        if res.ensemble_probas is not None:
            class1_1 = [pr[0, 0] for pr in res.ensemble_probas]
            class2_1 = [pr[0, 1] for pr in res.ensemble_probas]

            N = 4
            ind = np.arange(N)
            width = 0.35

            fig, ax = plt.subplots(figsize=(10, 7))

            p1 = ax.bar(
                ind,
                np.hstack(([class1_1[:-1], [0]])),
                width,
                color="green",
                edgecolor="k",
            )
            p2 = ax.bar(
                ind + width,
                np.hstack(([class2_1[:-1], [0]])),
                width,
                color="lightgreen",
                edgecolor="k",
            )

            # bars for VotingClassifier
            p3 = ax.bar(
                ind, [0, 0, 0, class1_1[-1]], width, color="blue", edgecolor="k"
            )
            p4 = ax.bar(
                ind + width,
                [0, 0, 0, class2_1[-1]],
                width,
                color="steelblue",
                edgecolor="k",
            )

            plt.axvline(2.8, color="k", linestyle="dashed")
            ax.set_xticks(ind + width)
            ax.set_xticklabels(
                [
                    "LogisticRegression\n",
                    "KNeighborsClassifier\n",
                    "SVC\n",
                    "VotingClassifier\n(average probabilities)",
                ],
                rotation=25,
                ha="right",
            )
            plt.ylim([0, 1])
            plt.title(f"Voting Classifier: {model}", fontsize=18, pad=20)
            plt.legend([p1[0], p2[0]], ["stroke = 0", "stroke = 1"], loc="upper left")
            plt.annotate(
                "*example for a single row ([0])",
                fontsize=12,
                xy=(0, -0.25),
                xycoords="axes fraction",
                xytext=(0, -40),
                textcoords="offset points",
                ha="left",
                va="top",
            )
            plt.tight_layout()
            plt.show()
In [75]:
confusion_matrices = {}

for k in cv_results_df.index:
    v = cv_results[k]
    confusion_matrices[k] = (v.cm_data, v.test_data)
importlib.reload(graph)
confusion_matrix_labels = ["No Stroke", "Stroke"]
confusion_matrix_axis_label = ""
In [76]:
importlib.reload(graph)

graph.roc_precision_recal_grid_plot(confusion_matrices, add_fbeta_25=True)
The figure layout has changed to tight
No description has been provided for this image

The precision-recall curves indicate very poor performance for all models: precision is very low at every threshold. Additionally, the curves are all:

  • non-monotonic, i.e., they change direction on the Y axis several times as the threshold is varied, due to fluctuating true and false positive counts.
  • steep: precision drops quickly (even at very low thresholds) and varies significantly because the models cannot consistently identify the sparse positive cases in this heavily imbalanced dataset.
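The non-monotonic behaviour can be reproduced on a tiny example. A sketch of a precision-recall sweep (pure NumPy, illustrative data only):

```python
import numpy as np

def pr_sweep(y_true, y_score):
    """Precision and recall at each descending-score cut-off (sketch)."""
    order = np.argsort(-np.asarray(y_score))
    y = np.asarray(y_true)[order]
    tp = np.cumsum(y)            # true positives accepted so far
    fp = np.cumsum(1 - y)        # false positives accepted so far
    return tp / (tp + fp), tp / y.sum()

# A single false positive between two true positives makes precision
# drop and then recover -- the curve is not monotonic.
precision, recall = pr_sweep([1, 0, 1, 0, 0], [0.9, 0.8, 0.7, 0.6, 0.5])
print(precision)  # precision: 1.0, 0.5, 0.667, 0.5, 0.4
```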
In [77]:
%matplotlib inline

importlib.reload(graph)

n = len(confusion_matrices)
columns = 2
rows = (n + 1) // columns
height = 8
width = height * columns

fig, axes = plt.subplots(
    rows, columns, figsize=(width, height * rows), constrained_layout=True
)
plt.suptitle("Confusion Matrices: Best Models based on f1", fontsize=20)

axes_flat = axes.flatten()
for i, (model_key, matrix_data) in enumerate(confusion_matrices.items()):
    graph.confusion_matrix_plot_v2(
        confusion_matrices[model_key][0],
        title=model_key,
        annotations=graph.make_annotations(confusion_matrices[model_key][1].metrics),
        ax=axes_flat[i],
    )

# Hide any unused axes
for j in range(i + 1, len(axes_flat)):
    axes_flat[j].axis("off")

plt.show()
No description has been provided for this image

Selecting the "Best" Model¶

We have been able to get relatively comparable results with all the complex boosting models, and our ensemble model performs similarly as long as some oversampling technique like SMOTE is used. With additional tuning it might provide effectively the same performance as XGBoost or CatBoost. However, training the ensemble (LogisticRegression + KNeighborsClassifier + SVC) is very slow, so it is still much more practical to use a boosting model that handles class balancing directly.

As far as performance and hyperparameter tuning go, the only parameter that really matters is the class weight, which directly controls the recall/precision trade-off (relative to the fbeta value selected for scoring).

With that in mind, we have selected TODO as our production model. While its overall performance is not ideal, it still provides reasonable performance relative to the assumptions outlined previously: XX% recall at XX% precision means that for every person with stroke = 1 we will also flag ~{N} other individuals as "high risk".
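The "~{N} individuals per true positive" figure follows directly from precision: at precision p, each true positive comes with (1 - p) / p false positives. A quick illustration using the target-class precision of the top model from the results table:

```python
def flagged_per_true_stroke(precision):
    """False positives accompanying each true positive at precision p."""
    return (1 - precision) / precision

# XGBoostTuneCatFBeta_25 reports target_precision ~= 0.114, so roughly
# 8 additional "high risk" individuals are flagged per actual stroke.
print(round(flagged_per_true_stroke(0.114), 1))  # 7.8
```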

Model Feature Importance and SHAP plots¶

We'll use SHAP values to further analyze the importance of each feature:

In [78]:
cm_target_model_key = "XGBoostTuneCatFBeta_25"
cm_target_model: ModelTrainingResult = cv_results[cm_target_model_key]

cm_alt_target_model_key = "XGBoostTuneCatFBeta_325"
cm_alt_target_model: ModelTrainingResult = cv_results[cm_alt_target_model_key]
In [79]:
for model_key in [cm_target_model_key]:
    config: ModelTrainingResult = cv_results[model_key]

    try:
        graph.render_feature_importances_chart(
            feature_importances=config.feature_importances,
            title=f"{model_key} Importances",
        )
    except Exception:
        pass

    model = config.test_data.test_model.named_steps["model"]
    x_test = config.test_data.x_test
    x_train = config.test_data.x_train
    y_test = config.test_data.y_test

    if "Linear" in model_key:
        explainer = shap.LinearExplainer(model, x_train)
    elif "SVM" in model_key:
        # KernelExplainer on a sampled background is too slow here, skip
        continue
    elif "Boost" in model_key:
        booster_model = model.get_booster()
        shap_values = booster_model.predict(
            xgboost.DMatrix(x_test, enable_categorical=True), pred_contribs=True
        )

        # drop the bias term appended as the last column
        shap_values = shap_values[:, :-1]
        shap.summary_plot(
            shap_values,
            features=x_test,
            feature_names=x_test.columns,
            plot_size=(12, 8),
        )
        continue
    else:
        try:
            explainer = shap.Explainer(model)
        except Exception:
            continue

    try:
        shap_values = explainer.shap_values(x_train)
        shap.summary_plot(shap_values, x_train, plot_size=(32, 7))
    except Exception:
        pass
set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
No description has been provided for this image
No description has been provided for this image

Probability Thresholds¶

An approach that might mitigate the precision/recall issue is to further split the risk group identified by our model into separate "Low", "Medium", "High" risk categories, which would allow us to use resources more effectively by giving more focus to the individuals with the highest risk:
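A minimal sketch of such a split, assuming hypothetical cut-offs of 0.3 and 0.6 (the actual thresholds would be chosen from the threshold-metrics chart below and a cost analysis, and the probabilities would come from the model's predict_proba output):

```python
import numpy as np
import pandas as pd

# Hypothetical predicted probabilities -- illustrative only
probas = np.array([0.05, 0.12, 0.35, 0.62, 0.88])

# Hypothetical cut-offs, not tuned values
tiers = pd.cut(probas, bins=[0.0, 0.3, 0.6, 1.0],
               labels=["Low", "Medium", "High"])
print(list(tiers))  # ['Low', 'Low', 'Medium', 'High', 'High']
```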

In [80]:
target_cm_data = cm_target_model.cm_data

importlib.reload(stats_utils)
importlib.reload(ml_config_core)
importlib.reload(graph)
summary_desc = (
    "The chart shows the model's performance if only individuals with stroke "
    "Prob. > T are selected. Additionally, the overlay indicates the number of "
    "people whose predicted P falls in a given range; it can be used to select "
    "the most at-risk individuals based on their predicted probability."
)

graph.plot_threshold_metrics_v2(
    target_cm_data,
    0,
    1,
    model_name=cm_target_model_key,
    class_pos=1,
    include_vars=["f1", "precision", "recall"],
    show_threshold_n=True,
)
display(Markdown(summary_desc))
Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.
No description has been provided for this image

The chart shows the model's performance if only individuals with stroke Prob. > T are selected. Additionally, the overlay indicates the number of people whose predicted P falls in a given range; it can be used to select the most at-risk individuals based on their predicted probability.

Conclusion¶

  • We have tried multiple different ML models to predict the stroke column.
  • While the best models achieve high recall for the stroke = 1 class (~0.79 for XGBoostTuneCatFBeta_25), precision remains very low (~0.11).
    • This is an issue because the large majority of individuals flagged as "high risk" will not actually suffer a stroke, so follow-up resources are spread thin.
    • On the positive side, the model misses relatively few actual stroke cases, which matters given that the cost of a missed stroke is likely far higher than the cost of additional screening.

Limitations and Suggestions for Future Improvements:¶

Business Case/Interpretation¶
  • A deeper cost-based analysis should be performed (ideally based on data from specific insurance companies, government healthcare systems, etc.) to determine the acceptable precision/recall ratio. While the direct and indirect costs of an individual suffering a stroke might be high:
    • It's not clear what real benefits identifying potential stroke victims provides. If the risk is mostly driven by lifestyle choices, additional treatment and monitoring would not be particularly useful if the patients are unwilling to alter their lifestyles.
    • Potentially this model could be used in a consumer-facing app for self-identification purposes (i.e. to prompt lifestyle changes).
Technical¶
  • Tuning for 'log_loss' instead of a classification metric.
  • Tweaking the decision threshold during hyperparameter tuning might be beneficial.
  • Using AUCPR for tuning.
  • Overfitting-control hyperparameters like 'early_stopping_rounds' could be utilized to cut model training short {TODO}